{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive Predictions\n",
    "This notebook walks through the preprocessing pipeline of the `CodeTransformer` and shows how to predict the method name for an arbitrary code snippet in any of the five languages (Java, Python, JavaScript, Ruby, and Go) that we explored in the paper.  \n",
    "Once you have downloaded the respective models and dataset files (the vocabularies and data configs are needed for inference) and set up the paths in `env.py`, you can load any model listed in the README and feed it any code snippet to obtain a prediction for the method name."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%cd ..\n",
    "%reload_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from code_transformer.preprocessing.datamanager.preprocessed import CTPreprocessedDataManager\n",
    "from code_transformer.preprocessing.graph.binning import ExponentialBinning\n",
    "from code_transformer.preprocessing.graph.distances import PersonalizedPageRank, ShortestPaths, \\\n",
    "    AncestorShortestPaths, SiblingShortestPaths, DistanceBinning\n",
    "from code_transformer.preprocessing.graph.transform import DistancesTransformer\n",
    "from code_transformer.preprocessing.nlp.vocab import VocabularyTransformer, CodeSummarizationVocabularyTransformer\n",
    "from code_transformer.preprocessing.pipeline.stage1 import CTStage1Preprocessor\n",
    "from code_transformer.preprocessing.pipeline.stage2 import CTStage2MultiLanguageSample\n",
    "from code_transformer.utils.inference import get_model_manager, make_batch_from_sample, decode_predicted_tokens\n",
    "from code_transformer.env import DATA_PATH_STAGE_2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1. Load Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.1. Specify run ID\n",
    "All our models are listed in the [README](../README.md) together with their corresponding `run_id` and stored snapshot."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_type = 'code_transformer'  # code_transformer, great or xl_net\n",
    "run_id = 'CT-6'  # Name of folder in which snapshots are stored\n",
    "snapshot = 'latest'  # Use 'latest' for the last stored snapshot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_manager = get_model_manager(model_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = model_manager.load_config(run_id)\n",
    "\n",
    "language = model_config['data_setup']['language']\n",
    "print(f\"Model was trained on: {language}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2. Construct model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = model_manager.load_model(run_id, snapshot, gpu=False)\n",
    "model = model.eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2. Specify any code snippet\n",
    "The code snippet has to be written in the target language, and the method whose name should be predicted has to be named `f`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"\n",
    "\"\"\"\n",
    "code_snippet_language = ''  # java, javascript, python, ruby, go"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
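    "## 2.1. Examples from Paper\n",
    "Each of the following cells overwrites `code_snippet` and `code_snippet_language`, so run only the cell for the example you want to try."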
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"public int f(Pair<LoggedJob, JobTraceReader> p1,Pair<LoggedJob, JobTraceReader> p2) {\n",
    "    LoggedJob j1 = p1.first();\n",
    "    LoggedJob j2 = p2.first();\n",
    "    return(j1.getSubmitTime() < j2.getSubmitTime()) ? -1 : (j1.getSubmitTime() == j2.getSubmitTime()) ? 0 : 1;\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"public static MNTPROC f(int value) {\n",
    "    if(value < 0 || value >= values().length) {\n",
    "        return null;\n",
    "    }\n",
    "    return values()[value];\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"private Iterable<ListBlobItem> f(String aPrefix, boolean useFlatBlobListing, EnumSet<BlobListingDetails> listingDetails, BlobRequestOptions options, OperationContext opContext) throws StorageException, URISyntaxException {\n",
    "    CloudBlobDirectoryWrapper directory = this.container.getDirectoryReference(aPrefix);\n",
    "    return directory.listBlobs(null, useFlatBlobListing, listingDetails, options, opContext);\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"private static void f(EnumMap<FSEditLogOpCodes, Holder<Integer>> opCounts) {\n",
    "    StringBuilder sb = new StringBuilder();\n",
    "    sb.append(\"Summary of operations loaded from edit log:  \");\n",
    "    Joiner.on(\"  \").withKeyValueSeparator(\"=\").appendTo(sb, opCounts);\n",
    "    FSImage.LOG.debug(sb.toString());\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"static String f(File f, String... cmd) throws IOException {\n",
    "    String[] args = new String[cmd.length + 1];\n",
    "    System.arraycopy(cmd, 0, args, 0, cmd.length);\n",
    "    args[cmd.length] = f.getCanonicalPath();\n",
    "    String output = Shell.execCommand(args);\n",
    "    return output;\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"protected void f(Class<? extends SubView> cls) {\n",
    "    indent(of(ENDTAG));\n",
    "    sb.setLength(0);\n",
    "    out.print(sb.append('[').append(cls.getName()).append(']').toString());\n",
    "    out.println();\n",
    "}\"\"\"\n",
    "code_snippet_language = 'java'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "code_snippet = \"\"\"\n",
    "function f() {\n",
    "    var quotes = new Array();\n",
    "    quotes[0] = \"Action is the real measure of intelligence.\";\n",
    "    quotes[1] = \"Baseball has the great advantage over cricket of being sooner ended.\";\n",
    "    quotes[2] = \"Every goal, every action, every thought, every feeling one experiences, whether it be consciously or unconsciously known, is an attempt to increase one's level of peace of mind.\";\n",
    "    quotes[3] = \"A good head and a good heart are always a formidable combination.\";\n",
    "    var rand = Math.floor(Math.random()*quotes.length);\n",
    "    document.write(quotes[rand]);\n",
    "}\n",
    "\"\"\"\n",
    "code_snippet_language = 'javascript'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3. Preprocess"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.1. Stage 1 (AST generation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocessor = CTStage1Preprocessor(code_snippet_language, allow_empty_methods=True)\n",
    "stage1_sample = preprocessor.process([(\"f\", \"\", code_snippet)], 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.2. Stage 2 (Distance matrices)\n",
    "Inference has to reproduce exactly the preprocessing that the model was trained with. To this end, we load the respective dataset config that was stored during preprocessing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the config of the respective dataset that this model was trained on\n",
    "model_language = model_config['data_setup']['language']\n",
    "data_manager = CTPreprocessedDataManager(DATA_PATH_STAGE_2, model_language, partition='train', shuffle=True)\n",
    "data_config = data_manager.load_config()\n",
    "\n",
    "# Extract how distances should be computed from the dataset config\n",
    "distances_config = data_config['distances']\n",
    "PPR_ALPHA = distances_config['ppr_alpha']\n",
    "PPR_USE_LOG = distances_config['ppr_use_log']\n",
    "PPR_THRESHOLD = distances_config['ppr_threshold']\n",
    "\n",
    "SP_THRESHOLD = distances_config['sp_threshold']\n",
    "\n",
    "ANCESTOR_SP_FORWARD = distances_config['ancestor_sp_forward']\n",
    "ANCESTOR_SP_BACKWARD = distances_config['ancestor_sp_backward']\n",
    "ANCESTOR_SP_NEGATIVE_REVERSE_DISTS = distances_config['ancestor_sp_negative_reverse_dists']\n",
    "ANCESTOR_SP_THRESHOLD = distances_config['ancestor_sp_threshold']\n",
    "\n",
    "SIBLING_SP_FORWARD = distances_config['sibling_sp_forward']\n",
    "SIBLING_SP_BACKWARD = distances_config['sibling_sp_backward']\n",
    "SIBLING_SP_NEGATIVE_REVERSE_DISTS = distances_config['sibling_sp_negative_reverse_dists']\n",
    "SIBLING_SP_THRESHOLD = distances_config['sibling_sp_threshold']\n",
    "\n",
    "# Extract how distances should be binned from the dataset config\n",
    "binning_config = data_config['binning']\n",
    "EXPONENTIAL_BINNING_GROWTH_FACTOR = binning_config['exponential_binning_growth_factor']\n",
    "N_FIXED_BINS = binning_config['n_fixed_bins']\n",
    "NUM_BINS = binning_config['num_bins']\n",
    "\n",
    "preprocessing_config = data_config['preprocessing']\n",
    "REMOVE_PUNCTUATION = preprocessing_config['remove_punctuation']\n",
    "\n",
    "# Put together all the implementations of the different distance metrics\n",
    "distance_metrics = [\n",
    "    PersonalizedPageRank(threshold=PPR_THRESHOLD, log=PPR_USE_LOG, alpha=PPR_ALPHA),\n",
    "    ShortestPaths(threshold=SP_THRESHOLD),\n",
    "    AncestorShortestPaths(forward=ANCESTOR_SP_FORWARD, backward=ANCESTOR_SP_BACKWARD,\n",
    "                          negative_reverse_dists=ANCESTOR_SP_NEGATIVE_REVERSE_DISTS,\n",
    "                          threshold=ANCESTOR_SP_THRESHOLD),\n",
    "    SiblingShortestPaths(forward=SIBLING_SP_FORWARD, backward=SIBLING_SP_BACKWARD,\n",
    "                         negative_reverse_dists=SIBLING_SP_NEGATIVE_REVERSE_DISTS,\n",
    "                         threshold=SIBLING_SP_THRESHOLD)]\n",
    "\n",
    "db = DistanceBinning(NUM_BINS, N_FIXED_BINS, ExponentialBinning(EXPONENTIAL_BINNING_GROWTH_FACTOR))\n",
    "\n",
    "distances_transformer = DistancesTransformer(distance_metrics, db)\n",
    "vocabs = data_manager.load_vocabularies()\n",
    "if len(vocabs) == 4:\n",
    "    vocabulary_transformer = CodeSummarizationVocabularyTransformer(*vocabs)\n",
    "else:\n",
    "    vocabulary_transformer = VocabularyTransformer(*vocabs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now, take the result of stage1 preprocessing and feed it through the vocabulary and distances transformer to obtain a stage2 sample\n",
    "\n",
    "stage2_sample = stage1_sample[0]\n",
    "if REMOVE_PUNCTUATION:\n",
    "    stage2_sample.remove_punctuation()\n",
    "stage2_sample = vocabulary_transformer(stage2_sample)\n",
    "stage2_sample = distances_transformer(stage2_sample)\n",
    "\n",
    "if ',' in model_language:\n",
    "    # In the multi-lingual setting, we additionally have to bake the code snippet language into the sample\n",
    "    stage2_sample = CTStage2MultiLanguageSample(stage2_sample.tokens, stage2_sample.graph_sample, stage2_sample.token_mapping,\n",
    "                                                stage2_sample.stripped_code_snippet, stage2_sample.func_name,\n",
    "                                                stage2_sample.docstring,\n",
    "                                                code_snippet_language,\n",
    "                                                encoded_func_name=getattr(stage2_sample, 'encoded_func_name', None))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3.3. Prepare sample to feed into model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = make_batch_from_sample(stage2_sample, model_config, model_type)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4. Prediction from model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "output = model.forward_batch(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
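    "k = 3  # Number of ranked predictions to show\n",
    "# Pick the k highest-scoring vocabulary entries at every output position;\n",
    "# after the transpose, row i holds the (i + 1)-th ranked token for each position\n",
    "predictions = output.logits \\\n",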
    "    .topk(k, axis=-1)\\\n",
    "    .indices\\\n",
    "    .squeeze()\\\n",
    "    .T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Predicted method names:')\n",
    "for i, prediction in enumerate(predictions):\n",
    "    predicted_method_name = decode_predicted_tokens(prediction, batch, data_manager)\n",
    "    print(f\"  ({i + 1}) \", ' '.join(predicted_method_name))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.1. Code Snippet embedding\n",
    "To obtain a meaningful embedding of the provided AST/source code pair, one can use the query stream embedding of the masked method name token in the final encoder layer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder_output = model.lm_encoder.forward_batch(batch, need_all_embeddings=True)\n",
    "query_stream_embedding = encoder_output.all_emb[-1][1]  # [1, B, D]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:code-parser]",
   "language": "python",
   "name": "conda-env-code-parser-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
